# -*- coding: utf-8 -*-
########################################################################
# 
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
# 
########################################################################
 

import gc
import numbers
import sys

from shapely.geometry import Point, LineString

from genregion.region import geometry as mygeo
from genregion.region import error
from genregion.generate.gen import cluster
from genregion.generate.gen import generator
from genregion.generate.gen import regionalg
from genregion.generate.gen import segspliter

def rect_cluster(points, width=40):
    """Group points with the hierarchical clustering algorithm.

    Args:
        points (list): mygeo.Point objects to be clustered.
        width (int, optional): Threshold handed to cluster.HCAlgorithm.
            Defaults to 40.

    Returns:
        list: The clusters returned by HCAlgorithm.run.
    """
    return cluster.HCAlgorithm(width).run(points)


def rect_classify(points, clusters, width=40):
    """Assign additional points to already-built clusters.

    Args:
        points (list): mygeo.Point objects to assign.
        clusters (list): Clusters previously produced by the hierarchical
            clustering step.
        width (int, optional): Threshold handed to cluster.Classifier.
            Defaults to 40.

    Returns:
        list: The points that could not be assigned to any cluster.
    """
    return cluster.Classifier(clusters, width).run(points)


def region_filter(regions, grid_size=1024, thread_num=4):
    """Remove regions that are subregions of other regions.

    Args:
        regions (list): Regions that may contain nested subregions.
        grid_size (int, optional): Grid cell size used to partition the
            regions (also across threads). Defaults to 1024.
        thread_num (int, optional): Number of threads; a value of 1 or less
            selects the single-threaded path. Defaults to 4.

    Returns:
        list: The filtered regions.
    """
    region_filt = regionalg.RegionFilter(regions)
    if thread_num <= 1:
        return region_filt.run(grid_size)
    return region_filt.multi_run(grid_size, thread_num)


def split_segs_by_segs(segs, grid_size=1024):
    """Split segments at their mutual intersections.

    Every intersection between pairs of segments is located and the segments
    are cut there, so that in the output each segment touches other segments
    only at its two endpoints. This is the preparation step for region
    generation.

    Args:
        segs (list): The segments to split.
        grid_size (int, optional): Grid cell size used by SegSpliter.
            Defaults to 1024.

    Returns:
        list: The split segments.
    """
    return segspliter.SegSpliter(segs, grid_size).run()


def splitedsegs_2_regions(segs):
    """Build regions from segments that were already split at intersections.

    Args:
        segs (list): Segments that connect to one another only at their
            endpoints (e.g. the output of split_segs_by_segs).

    Returns:
        list: The regions produced by generator.RegionGenerator.
    """
    return generator.RegionGenerator(segs).run()


def merge_regions(regions, grid_size=1024,
                  area_thres=10000, width_thres=20):
    """Merge regions using area/width thresholds and regenerate the regions.

    The regions are passed to regionalg.RegionMerger, which returns segments.
    Those segments are then turned back into regions with
    generator.RegionGenerator. Merging with pre-existing "raw" regions is
    currently disabled: RegionMerger always receives ``raw_regions=None``.

    Args:
        regions (list): The regions to merge.
        grid_size (int, optional): Grid cell size used by RegionMerger.
            Defaults to 1024.
        area_thres (int, optional): Area threshold forwarded to
            RegionMerger.run; presumably regions below this area are merged
            into neighbours (confirm in regionalg). Defaults to 10000.
        width_thres (int, optional): Width threshold forwarded to
            RegionMerger.run, used to deal with narrow regions.
            Defaults to 20.

    Returns:
        list: The regions regenerated from the merged segments.
    """
    # Merging with externally supplied raw regions is switched off for now.
    raw_regions = None
    merger = regionalg.RegionMerger(regions, raw_regions)
    segs = merger.run(grid_size, area_thres, width_thres)
    return generator.RegionGenerator(segs).run()


def __cluster_points(segments, grid_size, clust_width):
    """Cluster the endpoints of all input segments.

    Args:
        segments (str or list): Either the path of a segment file, or any
            in-memory representation accepted by generate_segments.
        grid_size (int): Grid cell size (accepted for interface symmetry;
            not used by this function).
        clust_width (int): Clustering threshold passed to rect_cluster.

    Returns:
        list: The clusters built from the distinct segment endpoints.
    """
    if not isinstance(segments, str):
        error.debug("The number of segments: %d" % (len(segments)))
        parsed = generate_segments(segments)
    else:
        # A string is interpreted as the path of a segment file.
        error.debug("Loading segments from the file: %s" % (segments))
        parsed = mygeo.gen_segments(segments)
    endpoints = generator.segments_to_cluster_points(parsed)
    # Only the endpoints are needed from here on; drop the segment list.
    del parsed
    error.debug("The number of different segment endpoints: %d" % (len(endpoints)))
    result = rect_cluster(endpoints, clust_width)
    del endpoints
    gc.collect()
    error.debug("The number of clusters: %d" % (len(result)))
    return result

def generate_segments(original_segments):
    """Convert segment representations to a list of mygeo.Segment objects.

    Each element of ``original_segments`` must describe exactly one segment,
    in one of these forms:
        * a shapely LineString with exactly two coordinates;
        * a pair of shapely Points, e.g. ``(Point_1, Point_2)``;
        * a pair of indexable coordinates, e.g. ``((x1, y1), (x2, y2))``.
    Elements are converted independently, so the forms may be mixed, and the
    two endpoints of a pair may also mix Points and coordinates. Only the
    first two ordinates of each coordinate are used. Any real number
    (including numpy scalars) is accepted as an ordinate.

    Args:
        original_segments (iterable): A non-empty iterable of segment
            representations.

    Raises:
        error.RegionError: "Empty object" if the input is None or empty.
        error.RegionError: "Unsupported type of input." if an element (or
            one of its endpoints) is not one of the supported forms.

    Returns:
        list: A list of mygeo.Segment objects, in input order.
    """
    if original_segments is None:
        raise error.RegionError("Empty object")
    # Materialize once so that generators and other one-shot iterables work.
    try:
        items = list(original_segments)
    except TypeError:
        raise error.RegionError("Unsupported type of input.")
    if not items:
        # Previously an empty list escaped as a bare IndexError.
        raise error.RegionError("Empty object")
    segs = []
    for item in items:
        if isinstance(item, LineString):
            coords = list(item.coords)
            if len(coords) != 2:
                # Polylines would silently need splitting; reject explicitly.
                raise error.RegionError("Unsupported type of input.")
            start, end = coords
        else:
            try:
                start, end = item
            except (TypeError, ValueError):
                raise error.RegionError("Unsupported type of input.")
        segs.append(mygeo.Segment(_to_mygeo_point(start), _to_mygeo_point(end)))
    return segs


def _to_mygeo_point(pt):
    """Convert a shapely Point or an indexable (x, y, ...) coordinate to a mygeo.Point."""
    if isinstance(pt, Point):
        return mygeo.Point(pt.x, pt.y)
    try:
        x, y = pt[0], pt[1]
    except (TypeError, IndexError, KeyError):
        raise error.RegionError("Unsupported type of input.")
    # numbers.Real also admits numpy scalars (np.int64, np.float32, ...),
    # which a plain int/float isinstance check rejects.
    if not (isinstance(x, numbers.Real) and isinstance(y, numbers.Real)):
        raise error.RegionError("Unsupported type of input.")
    return mygeo.Point(x, y)


def generate_regions(segments, grid_size=1024, \
                     area_thres=10000, width_thres=20, \
                     clust_width=25, point_precision=0):
    """Generate regions from a collection of line segments (e.g. a road network).

    Pipeline:
        1. Cluster all segment endpoints (__cluster_points) and build a point
           map from the clusters (generator.clusters_to_pointmap).
        2. Load the segments again and simplify them with that point map
           (generator.simplify_by_pointmap).
        3. Split segments at their mutual intersections (split_segs_by_segs).
        4. Build regions from the split segments (splitedsegs_2_regions).
        5. Merge regions using the area/width thresholds (merge_regions).
        6. Filter out subregions (region_filter, with its default thread count).
    Intermediate objects are deleted and gc.collect() is called between
    stages, presumably to keep peak memory low on large inputs.

    Four input representations are supported:
        1. The path (str) of a file storing every segment of the network.
           Each line holds the coordinates of the segment's two points,
           separated by a comma, e.g.
           255834.51327326 3323376.71868603,260889.23516149 3321991.45674967
        2. A list of tuples of shapely Points:
           [(Point_1, Point_2), (Point_3, Point_4), ...]
        3. A list of tuples of point coordinates:
           [((x1, y1), (x2, y2)), ((x3, y3), (x4, y4)), ...]
        4. A list of shapely LineString objects:
           [LineString_1, LineString_2, ...]
    Important: every element must represent a single two-point segment.

    Args:
        segments (str or list): See the supported representations above.
        grid_size (int, optional): Grid cell size used to build the grid
            dictionaries for spatial searching in the splitting, merging and
            filtering stages. Defaults to 1024.
        area_thres (int, optional): The minimum area of a generated region;
            forwarded to merge_regions. Defaults to 10000.
        width_thres (int, optional): The minimum ratio of area/perimeter;
            forwarded to merge_regions. Defaults to 20.
        clust_width (int, optional): Clustering threshold for segment
            endpoints; forwarded to __cluster_points. Defaults to 25.
        point_precision (int, optional): Precision passed to
            mygeo.Point.set_precision before any processing. Defaults to 0.

    Returns:
        list: The result of region.polygon() for each final region
            (presumably shapely Polygons -- confirm against the region class).
    """
    # NOTE(review): this is a call on the Point class itself, so it presumably
    # changes the precision globally for later Point objects too -- confirm.
    mygeo.Point.set_precision(point_precision)
    clusters = __cluster_points(segments, grid_size, clust_width)
    pointmap = generator.clusters_to_pointmap(clusters)
    del clusters
    gc.collect()
    error.debug("The pointmap is finished.")
    # The segments are loaded/converted a second time here; the list built
    # inside __cluster_points is not kept, presumably to limit memory use.
    if isinstance(segments, str):
        segs = generator.simplify_by_pointmap(mygeo.gen_segments(segments), pointmap)
    else:
        segs = generator.simplify_by_pointmap(generate_segments(segments), pointmap)
    del pointmap
    gc.collect()
    error.debug("The number of original segments: %d" % (len(segs)))
    # Cut segments at intersections so they only meet at endpoints.
    segs = split_segs_by_segs(segs, grid_size)
    gc.collect()
    error.debug("The number of splitted segments.: %d" % (len(segs)))
    regs = splitedsegs_2_regions(segs)
    del segs
    gc.collect()
    error.debug("The number of the regions generated at the first step: %d" % (len(regs)))
    regs = merge_regions(regs, grid_size, area_thres, width_thres)
    gc.collect()
    error.debug("The number of regions after merging those small ones: %d" % (len(regs)))
    # thread_num is left at region_filter's default.
    regs = region_filter(regs, grid_size)
    gc.collect()
    error.debug("The number of regions after filter out subregions: %d" % (len(regs)))
    return [region.polygon() for region in regs]

